/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "defines.h"
#include "mha.h"
#if IS_MLA
#include "barriers.cuh"
#include "mhaUtils.cuh"
#include "mha_components.cuh"
#include "mha_stdheaders.cuh"
#include "mla_sm120.cuh"
#include "mma.cuh"
#include "tma.h"
#include "utils.cuh"
#include "utils.h"

#ifndef GENERATE_CUBIN
#include <cuda_runtime.h>

#include "hostUtils.h"
#include "tensorMap.h"
#endif

#define USE_REG_Q 1

__constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA;

inline constexpr bool allowMultipleInputTokens = true;

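// For GEMM0 (Q*K^T), the K head dimension is processed in nbKParts parts of partElemsK elements
// each; Q is partitioned the same way so each part can be loaded and consumed independently.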
inline constexpr uint32_t partElemsK = 64;  // @fixme: change this to 128 to save L2 traffic
inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK);
inline constexpr uint32_t nbQParts = nbKParts;

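// Each KV tile covers tokensPerTile tokens. For GEMM1 (X*V), the V head dimension is split into
// nbVSplit halves of gemm1V elements, one half per consumer CTA; the CGA pairs them with
// nbProducerCtasPerCga producer CTAs.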
inline constexpr uint32_t tokensPerTile = 64;
inline constexpr uint32_t partElemsV = 128;
inline constexpr uint32_t nbVSplit = 2;
inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit);
inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit;

inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2;
inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2;

using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes);

inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes);

inline constexpr float xScale = 1.f / kE4M3_MAX;
__constant__ constexpr float rcpXScale = kE4M3_MAX;

inline constexpr uint32_t nbRegsForIOWarps = 32;
inline constexpr uint32_t nbRegsForMathWarps = 232;

inline constexpr bool computeRowSumFromF8 = true;

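// Loads parts of one KV tile (tokensPerTile tokens x a slice of the head dimension) from the
// paged KV cache via TMA, using the page indices fetched by loadPages() into pageBuffers.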
struct KVTilePartLoader {
  static_assert(tokensPerPage % tokensPerTile == 0 || tokensPerTile % tokensPerPage == 0);
  static inline constexpr uint32_t nbPagesPerTile =
      tokensPerTile >= tokensPerPage ? exactDiv(tokensPerTile, tokensPerPage) : 1;

  static inline constexpr uint32_t const nbKHeads = 1;
  KVCacheList<usePagedKVCache> const& cacheList;
  uint32_t const idxReq;
  static inline constexpr uint32_t const idxHeadGrp = 0;

  CUtensorMap const& tensorMap;
  // If greater than 1, the loading loop needs unrolling. 1 seems to be fine for latency.
  static inline constexpr uint32_t nbPageBuffers = 1;
  uint32_t const nbPages;  // for bound check
  Vec<KVCachePageIndex, nbPagesPerTile> pageBuffers[nbPageBuffers];
  uint32_t idxTileRef = ~0U;  // idxTile used to load the pages
  uint32_t const baseOffset;

  __device__ KVTilePartLoader(KVCacheList<usePagedKVCache> const& cacheList, uint32_t idxReq,
                              CUtensorMap const& tensorMap, uint32_t nbPages);
  // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
  template <uint32_t nbTokens, uint32_t grainsPerPart, bool alignedForSwizzle>
  __device__ void loadData(Array2D<LdGrain, nbTokens, grainsPerPart, alignedForSwizzle>& dst,
                           uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar,
                           uint32_t idxPageBuf);

  __device__ void loadPages(uint32_t idxTile, uint32_t idxPageBuf);
};

__device__ inline KVTilePartLoader::KVTilePartLoader(KVCacheList<usePagedKVCache> const& cacheList,
                                                     uint32_t idxReq, CUtensorMap const& tensorMap,
                                                     uint32_t nbPages)
    : cacheList{cacheList},
      idxReq{idxReq},
      tensorMap{tensorMap},
      nbPages{nbPages},
      baseOffset{idxReq * cacheList.maxNbPagesPerSeq} {
#pragma unroll
  for (auto& pageBuffer : pageBuffers) {
    pageBuffer.fill(kBAD_PAGE_INDEX);
  }
}

// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
template <uint32_t nbTokens, uint32_t grainsPerPart, bool alignedForSwizzle>
__device__ inline void KVTilePartLoader::loadData(
    Array2D<LdGrain, nbTokens, grainsPerPart, alignedForSwizzle>& dst, uint32_t idxTile,
    uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf) {
  static_assert(nbTokens == tokensPerTile);
  assert(idxTile == idxTileRef);
  auto const& pages = pageBuffers[idxPageBuf];
  if constexpr (nbTokens < tokensPerPage) {
    assert(nbPagesPerTile == 1);
    uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens));
    if (warpElectSync()) {
      tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, idxHeadGrp, offset, (uint32_t)pages[0]},
                     bar);
    }
  } else {
#pragma unroll
    for (uint32_t i = 0; i < nbPagesPerTile; i++) {
      if (warpElectSync()) {
        tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
                       DimsLE<4>{idxElemBeg, idxHeadGrp, 0, (uint32_t)pages[i]}, bar);
      }
    }
  }
}

__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t idxPageBuf) {
  uint32_t const idxPageBeg = tokensPerTile >= tokensPerPage
                                  ? nbPagesPerTile * idxTile
                                  : idxTile / exactDiv(tokensPerPage, tokensPerTile);
  auto& pages = pageBuffers[idxPageBuf];
#pragma unroll
  for (uint32_t i = 0; i < nbPagesPerTile; i++) {
    uint32_t const idxPage = idxPageBeg + i;
    pages[i] =
        idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX;
  }
  idxTileRef = idxTile;
}

using Mat16x32 = Vec<uint32_t, 4>;

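// Issues ldmatrix x4 loads of 16-row x 32-byte fragments from swizzled shared memory, used as
// QMMA operands; the default r/c lane mapping is for mat A.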
template <uint32_t srcRows, uint32_t srcCols>
class Mat16x32Loader {
 public:
  using Src = Array2D<LdGrain, srcRows, srcCols>;

  // default r and c are for mat A.
  __device__ inline Mat16x32Loader(Src const& src, uint32_t baseRow, uint32_t idxInstK,
                                   uint32_t r = laneId() % 16, uint32_t c = laneId() / 16)
      : src{src}, baseRow{baseRow}, idxInstK{idxInstK}, r{r}, c{c}, basePtr{getPtrRef(0)} {
    static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0);
  }

  __device__ inline Mat16x32 load(uint32_t idxInstM) const {
    return ldmatrix<false, 4>(getPtr(idxInstM));
  }

  template <uint32_t tileM>
  __device__ inline Vec<Mat16x32, exactDiv(tileM, qmmaShape.m)> loadWholeCol() const {
    uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m);
    Vec<Mat16x32, nbInstM> ret;
#pragma unroll
    for (uint32_t i = 0; i < nbInstM; i++) {
      ret[i] = load(i);
    }
    return ret;
  }

  __device__ inline LdGrain const* getPtr(uint32_t idxInstM) const {
    return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM));
  }

 private:
  __device__ inline LdGrain const* getPtrRef(uint32_t idxInstM) const {
    return &src.template at<true>(baseRow + idxInstM * qmmaShape.m + r,
                                  idxInstK * exactDiv(qmmaShape.k, grainElems) + c);
  }

  Src const& src;
  uint32_t const baseRow;
  uint32_t const idxInstK;
  uint32_t const r;
  uint32_t const c;
  LdGrain const* const basePtr;
};

using InstAcc = Array2D<float, 2, 2>;

using XBuffer = Array2D<LdGrain, headGrpSize, exactDiv(mathElemBytes * tokensPerTile, grainBytes)>;

struct CgaXBuffer {
  XBuffer x;
  Vec<float, headGrpSize> rowSum;
  Vec<float, headGrpSize> rowMaxLog2e;
};

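// Alternates tensor-core ownership between the two math groups of a producer CTA: lock() waits on
// the group's own barrier and unlock() arrives on the other group's barrier. initStorage() gives
// group 0 the first turn.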
struct PingPongMutex {
  using ShmStorage = CtaBarrier[2];
  ShmStorage& barriers;
  uint32_t const idxGrp;
  bool skipWait = false;

  static __device__ inline void initStorage(ShmStorage& barriers, uint32_t thrdsPerGrp) {
    new (&barriers[0]) CtaBarrier(thrdsPerGrp);
    new (&barriers[1]) CtaBarrier(thrdsPerGrp);
    barriers[0].arrive(thrdsPerGrp);
  }

  __device__ inline PingPongMutex(ShmStorage& shmStorage, uint32_t idxGrp)
      : barriers{shmStorage}, idxGrp{idxGrp} {}

  __device__ inline void test_lock(uint32_t iter) {
    skipWait = barriers[idxGrp].test_wait_parity(toParity<1>(iter));
  }

  __device__ inline void lock(uint32_t iter) {
    if (!skipWait) {
      barriers[idxGrp].wait_parity(toParity<1>(iter));
    }
  }

  __device__ inline void unlock() {
    barriers[idxGrp ^ 1U].arrive();
    skipWait = false;
  }
};

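// Per-sub-sequence partial output used in multi-block (sequence-split) mode. The output head rows
// are divided into 4 chunks of nbRowsPerChunk rows, each stored together with the rowSum/rowMax
// needed for merging; in mergePartialOutputs() each CTA of the last CGA combines one chunk.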
struct PartialResult {
  static constexpr uint32_t nbChunks = 4;
  static constexpr uint32_t nbRowsPerChunk = exactDiv(headGrpSize, nbChunks);

  struct Chunk {
    Vec<OutputHead, nbRowsPerChunk> data;
    Vec<float, nbRowsPerChunk> rowSum;
    Vec<float, nbRowsPerChunk> rowMaxLog2e;
  };

  Chunk chunks[nbChunks];
};

constexpr uint32_t nbMathWarpsA = 8;
constexpr uint32_t nbComputeWarpsB = 8;
constexpr uint32_t nbMathGrpsA = 2;
constexpr uint32_t nbMathWarpsB = 8;

constexpr uint32_t nbMultiBlockBufs = 2;
constexpr uint32_t multiBlockMathWarps = 8;

constexpr bool useRegQ = USE_REG_Q;

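// Shared memory layout of the producer CTAs (CGA ranks 0..1): staging buffers for the Q and K
// parts of GEMM0, the fp8 X tile (plus rowSum/rowMax) handed to the sendX() warp, and the
// barriers coordinating the IO warps with the two math groups.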
struct SharedMemA {
  static inline constexpr uint32_t nbKBufs = 12;

  static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0);
  static inline constexpr uint32_t shmQParts = nbQParts - regQParts;

  using ShmQPart = Array2D<LdGrain, headGrpSize, grainsPerPartK>;
  using ShmKPart = Array2D<LdGrain, tokensPerTile, grainsPerPartK>;

  Vec<ShmQPart, shmQParts> q;
  ShmKPart k[nbKBufs];

  // Single buffer reused by the two math groups. The sendX() warp arbitrates the order of x
  // buffer access via the two xBars.
  CgaXBuffer x;

  // Scaled by log2e. Written by the last CGA iteration (from the other producer CTA) and read by
  // the current producer CTA.
  Vec<float, headGrpSize> rowMaxLog2e;
  // Syncs rowMaxLog2e between the two producer CTAs; .consumed means the buffer for the next
  // iteration (in the next producer) is ready. The 4 groups from the 2 producer CTAs form a ring.
  CgaBarrier rowMaxLog2eBar[nbMathGrpsA];

  PingPongMutex::ShmStorage tensorCoreMutex;

  CtaBarrierPair kBars[nbKBufs];
  static inline constexpr uint32_t nbXBars = nbMathGrpsA;
  CtaBarrierPair xBars[nbXBars];
#if USE_REG_Q
  CtaBarrierPair regQBar;
#endif
  CtaBarrier shmQBar;
  CgaBarrier cgaXBufConsumed;  // for X

  CtaBarrierPair multiBlockBars[nbMultiBlockBufs];

  __device__ inline void invalidateBarriers(uint32_t thrdIdx) {
    constexpr uint32_t nbBars = (useRegQ ? 12 : 10) + 2 * (nbKBufs + nbXBars);
#ifndef __CUDACC_RTC__
    constexpr uint32_t nbBarsRef =
        exactDiv(offsetof(SharedMemA, qkScaleLog2e) - offsetof(SharedMemA, rowMaxLog2eBar), 8);
    static_assert(nbBars == nbBarsRef);
#endif
    if (thrdIdx < nbBars) {
      reinterpret_cast<CtaBarrier*>(&rowMaxLog2eBar[0])[thrdIdx].~CtaBarrier();
    }
  }

  __device__ inline Vec<PartialResult::Chunk, nbMultiBlockBufs>& getMultiBlockBufs() {
#ifndef __CUDACC_RTC__
    assert(sizeof(Vec<PartialResult::Chunk, nbMultiBlockBufs>) <
           offsetof(SharedMemA, rowMaxLog2eBar));
#endif
    return *reinterpret_cast<Vec<PartialResult::Chunk, nbMultiBlockBufs>*>(this);
  }

  float qkScaleLog2e;
  bool isLastSubSeq;
};

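// Shared memory layout of the consumer CTAs (CGA ranks 2..3): double-buffered X and V tiles for
// GEMM1 together with the running row max/sum used for online-softmax rescaling.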
struct SharedMemB {
  static inline constexpr uint32_t nbXVBufs = 2;
  static inline constexpr uint32_t nbXBufs = nbXVBufs;
  static inline constexpr uint32_t nbVBufs = nbXVBufs;

  using VBuffer = Vec<Array2D<LdGrain, tokensPerTile, exactDiv(partElemsV, grainElems)>,
                      exactDiv(gemm1V, partElemsV)>;

  // x and v use gemmK=128 per iteration. If shared memory capacity becomes a bottleneck, we can
  // change this to 64 in the future.
  struct XVBuffer {
    VBuffer v;
    CgaXBuffer x;
    uint8_t
        pad[headGrpSize * 128 * 2 - sizeof(VBuffer) - sizeof(CgaXBuffer)];  // for output swizzling
  };

  XVBuffer xv[nbXVBufs];

  __device__ inline XBuffer& x(uint32_t idx) { return xv[idx].x.x; }

  __device__ inline VBuffer& v(uint32_t idx) { return xv[idx].v; }

  __device__ inline Vec<float, headGrpSize>& xRowSum(uint32_t idx) { return xv[idx].x.rowSum; }

  __device__ inline Vec<float, headGrpSize>& xRowMaxLog2e(uint32_t idx) {
    return xv[idx].x.rowMaxLog2e;
  }

  static inline constexpr uint32_t nbAccRowMaxSumCopies = 2;
  Vec<float, headGrpSize> accRowMaxLog2e[nbAccRowMaxSumCopies];
  Vec<float, headGrpSize> accRowSum[nbAccRowMaxSumCopies];

  CtaBarrierPair xBars[nbXBufs];
  CtaBarrierPair vBars[nbVBufs];

  CgaBarrier cgaXBufProduced[nbProducerCtasPerCga];
  CtaBarrier mathWarpsBar;

  CtaBarrierPair multiBlockBars[nbMultiBlockBufs];

  __device__ inline void invalidateBarriers(uint32_t thrdIdx) {
    constexpr uint32_t nbBars = 15;
#ifndef __CUDACC_RTC__
    constexpr uint32_t nbBarsRef =
        exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xBars), 8);
    static_assert(nbBars == nbBarsRef);
#endif
    if (thrdIdx < nbBars) {
      reinterpret_cast<CtaBarrier*>(&xBars[0])[thrdIdx].~CtaBarrier();
    }
  }

  __device__ inline Vec<PartialResult::Chunk, nbMultiBlockBufs>& getMultiBlockBufs() {
#ifndef __CUDACC_RTC__
    static_assert(sizeof(Vec<PartialResult::Chunk, nbMultiBlockBufs>) <
                  offsetof(SharedMemB, xBars));
#endif
    return *reinterpret_cast<Vec<PartialResult::Chunk, nbMultiBlockBufs>*>(this);
  }

  bool isLastSubSeq;
};

__device__ void mergePartialOutputs(uint32_t& semaphore,
                                    Vec<OutputHead, PartialResult::nbRowsPerChunk>& dst,
                                    PartialResult const* reqPartialResults, uint32_t nbSubSeq,
                                    uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx,
                                    void* sharedMem);

struct KernelArgs {
  CUtensorMap const& tensorMapQ;  // MhaIOHead[nbQHeads * totalNbInputTokens]
  CUtensorMap const& tensorMapK;
  CUtensorMap const& tensorMapV;
  float const& qScale;
  OutputHead* __restrict__ const& output;  // [totalNbInputTokens][nbQHeads]
  KVCacheList<usePagedKVCache> const& cacheList;
  uint32_t const& batchSize;
  float kvCacheScale;  // Same scale for K and V cache. Used only for int8/fp8 KV cache.
  Vec<CgaXBuffer, nbProducerCtasPerCga>* __restrict__ const&
      cgaXBuf;                                        // [totalNbInputTokens][maxNbSubSeq]
  uint32_t* __restrict__ const& semaphores;           // [totalNbInputTokens]
  PartialResult* __restrict__ const& partialResults;  // [totalNbInputTokens][maxNbSubSeq]
};

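// Producer CTA: loads Q and K, computes S = Q*K^T for its tiles, applies masking and online
// softmax (sharing the running row max with the other producer CTA), converts the result to fp8
// and ships it, together with rowSum/rowMax, to both consumer CTAs through the L2 scratch buffer
// args.cgaXBuf.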
struct Producer {
  static inline constexpr uint32_t nbMathGrps = nbMathGrpsA;
  static inline constexpr uint32_t nbMathWarps = nbMathWarpsA;
  static inline constexpr uint32_t nbMathThrds = nbMathWarps * warp_size;
  static inline constexpr uint32_t warpsPerGrp = exactDiv(nbMathWarps, nbMathGrps);
  static inline constexpr uint32_t thrdsPerGrp = warpsPerGrp * warp_size;
  static inline constexpr uint2 warpTile = {tokensPerTile, exactDiv(headGrpSize, warpsPerGrp)};
  using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
  using ThrdRegRowMax = ThrdRegRowMaxT<warpTile.y>;
  using QuadRegRowMax = QuadRegRowMaxT<warpTile.y>;

  KernelArgs const& args;
  SharedMemA& smem;
  uint32_t const maxNbSubSeq;
  uint32_t const idxReq;
  uint32_t const idxInputTokenGlobal;
  uint32_t const nbSubSeq;
  uint32_t const idxSubSeq;
  uint32_t const seqLen;
  uint32_t const ctaRank;
  uint32_t const warpRank;
  uint2 const warpIdx;

  __device__ inline Producer(KernelArgs const& args, SharedMemA& smem, uint32_t const maxNbSubSeq,
                             uint32_t const idxReq, uint32_t idxInputTokenGlobal,
                             uint32_t const seqLen, uint32_t const nbSubSeq,
                             uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank,
                             uint2 const warpIdx)
      : args(args),
        smem(smem),
        maxNbSubSeq(maxNbSubSeq),
        idxReq(idxReq),
        idxInputTokenGlobal(idxInputTokenGlobal),
        seqLen(seqLen),
        nbSubSeq(nbSubSeq),
        idxSubSeq(idxSubSeq),
        ctaRank(ctaRank),
        warpRank(warpRank),
        warpIdx(warpIdx) {
#ifndef NDEBUG
    if (threadIdx.x == 0) {
      asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemA)) : "memory");
    }
    __syncthreads();
#endif
    if (threadIdx.x == 0) {
      smem.qkScaleLog2e = args.qScale * args.kvCacheScale * log2e;
    }

    if (threadIdx.x < headGrpSize) {
      smem.rowMaxLog2e[threadIdx.x] = safeInitRowMax;
    }
    if (warpElectSync()) {
      if (warpRank < SharedMemA::nbKBufs) {
        auto& b = smem.kBars[warpRank];
        b.initialize(1, thrdsPerGrp);
        b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(thrdsPerGrp);
      }
      if (warpRank < SharedMemA::nbXBars) {
        auto& b = smem.xBars[warpRank];
        b.initialize(thrdsPerGrp, 1);
      }
#if USE_REG_Q
      if (warpRank == 0) {
        smem.regQBar.initialize(1, nbMathThrds);
        smem.regQBar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
      }
#endif
      if (warpRank < nbMathGrpsA) {
        auto& b = smem.rowMaxLog2eBar[warpRank];
        init(&b, thrdsPerGrp);
      }
      if (ctaRank == 0 && warpRank == 0) {
        smem.rowMaxLog2eBar[0].arrive<Scope::CTA, ArriveOrder::RELAXED>(thrdsPerGrp);
      }
      if (warpRank == 0) {
        init(&smem.shmQBar, 1);
        init(&smem.cgaXBufConsumed, 1 * nbVSplit);
        smem.cgaXBufConsumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(1 * nbVSplit);
        PingPongMutex::initStorage(smem.tensorCoreMutex, thrdsPerGrp);
      }
      if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) {
        auto& b = smem.multiBlockBars[warpRank];
        b.initialize(1, warp_size * multiBlockMathWarps);
        b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(warp_size * multiBlockMathWarps);
      }
    }
    clusterBarArrive();
    clusterBarWait();
  }

  __device__ inline ~Producer() {
    clusterBarArrive();
    clusterBarWait();
    smem.invalidateBarriers(threadIdx.x);
  }

  __device__ inline void run() {
    if (warpIdx.y == 2) {  // IO warps
      asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
      if (warpIdx.x == 0) {  // q
        loadQ();
      } else if (warpIdx.x == 1) {  // k
        loadK();
      } else if (warpIdx.x == 2) {  // x
        sendX();
      }
    } else {  // Compute warps
      asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
      compute();
    }
    if (nbSubSeq > 1) {
      mergePartialOutputs(args.semaphores[idxInputTokenGlobal],
                          reinterpret_cast<Vec<OutputHead, PartialResult::nbRowsPerChunk>&>(
                              args.output[headGrpSize * idxInputTokenGlobal +
                                          PartialResult::nbRowsPerChunk * ctaRank]),
                          args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq,
                          ctaRank, warpRank, warpIdx, &smem);
    }
  }

 private:
  __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; }

  __device__ inline uint32_t idxTileBeg() const {
    return nbProducerCtasPerCga * idxSubSeq + ctaRank;
  }

  __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); }

  __device__ inline SharedMemB& getConsumerShm(uint32_t const idxConsumer) {
    return *mapa(reinterpret_cast<SharedMemB*>(&smem), nbProducerCtasPerCga + idxConsumer);
  };

  static constexpr uint32_t regQPartShmBeg = SharedMemA::shmQParts - SharedMemA::regQParts;

  __device__ inline void loadQ() {
#if USE_REG_Q
    static_assert(SharedMemA::regQParts <= SharedMemA::shmQParts);
    smem.regQBar.consumed.wait_parity(toParity<1>(0));
#pragma unroll 1
    for (uint32_t i = 0; i < SharedMemA::regQParts; i++) {
      if (warpElectSync()) {
        tma::loadAsync(&smem.q[regQPartShmBeg + i], args.tensorMapQ,
                       DimsLE<2>{partElemsK * i, headGrpSize * idxInputTokenGlobal},
                       smem.regQBar.produced);
      }
    }
    if (warpElectSync()) {
      smem.regQBar.produced.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::regQParts);
    }
#endif
#pragma unroll 1
    for (uint32_t i = 0; i < SharedMemA::shmQParts; i++) {
      uint32_t const idxPart = SharedMemA::regQParts + i;
#if USE_REG_Q
      if (i == regQPartShmBeg) {
        smem.regQBar.consumed.wait_parity(toParity<1>(1));
      }
#endif
      if (warpElectSync()) {
        tma::loadAsync(&smem.q[i], args.tensorMapQ,
                       DimsLE<2>{partElemsK * idxPart, headGrpSize * idxInputTokenGlobal},
                       smem.shmQBar);
      }
    }
    if (warpElectSync()) {
      smem.shmQBar.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::shmQParts);
    }
  }

  __device__ inline void loadK();

  __device__ inline void sendX();

  __device__ inline void compute() {
    uint32_t const grpIdx = warpIdx.y;
    uint32_t const tileBaseRow = warpTile.y * warpIdx.x;
    PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx};

    constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k);
    using AtomA = Vec<uint32_t, 4>;  // for 16x32 data, working as mat A of QMMA.16832
    using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, qmmaShape.m)>;
    using RegQPart = Vec<RegQPartCol, partNbInstK>;
    using RegQ = Vec<RegQPart, SharedMemA::regQParts>;
    constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, qmmaShape.n * 2);
    using AtomBx2 = Vec<uint32_t, 4>;  // one AtomB is 8x32 and AtomBx2 is 16x32
    using RegKPartCol = Vec<AtomBx2, tileNbAtomBx2>;
    using RegKPart = Vec<RegKPartCol, partNbInstK>;

    uint32_t const lane = laneId();
    uint32_t const rA = lane % 16;
    uint32_t const cA = lane / 16;
    uint32_t const rB = (lane / 16) * 8 + lane % 8;
    uint32_t const cB = (lane % 16) / 8;
    auto loadRegQCol = [&](SharedMemA::ShmQPart const& q, uint32_t idxInstK) -> RegQPartCol {
      Mat16x32Loader const loaderQ(q, tileBaseRow, idxInstK, rA, cA);
      return loaderQ.loadWholeCol<warpTile.y>();
    };
    auto loadRegKCol = [&](SharedMemA::ShmKPart const& k, uint32_t idxInstK) -> RegKPartCol {
      Mat16x32Loader const loaderK(k, 0, idxInstK, rB, cB);
      return loaderK.loadWholeCol<warpTile.x>();
    };
    auto loadPart = [&](auto const& loadCol, auto const& shmPart) {
      mha::conditional_t<mha::is_same_v<SharedMemA::ShmQPart, mha::decay_t<decltype(shmPart)>>,
                         RegQPart, RegKPart>
          regPart;
#pragma unroll
      for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) {
        regPart[idxInstK] = loadCol(shmPart, idxInstK);
      }
      return regPart;
    };

#if USE_REG_Q
    // load regQ
    smem.regQBar.produced.wait_parity(toParity<1>(0));
    RegQ regQ;
#pragma unroll
    for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) {
      uint32_t const idxBuf = regQPartShmBeg + idxPart;
      regQ[idxPart] = loadPart(loadRegQCol, smem.q[idxBuf]);
    }
    smem.regQBar.consumed.arrive();
#endif
// main loop
#pragma unroll 1
    for (uint32_t grpIter = 0; true; grpIter++) {
      uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps;
      uint32_t const idxTile = idxTileBeg() + iterStride() * ctaIter;
      if (idxTile >= nbTiles()) {
        break;
      }
      WarpAcc acc{};
      // wait until it's our turn
      tensorCoreMutex.lock(grpIter);
      BarWaiter kBarWaiter(smem.kBars, ctaIter * nbKParts + 0);
      kBarWaiter.testWait();
      RegQPart regQBuf;
#if USE_REG_Q
      static_assert(SharedMemA::regQParts > 0);
      regQBuf[0] = regQ[0][0];
#else
      regQBuf[0] = loadRegQCol(smem.q[0], 0);
#endif
      kBarWaiter.wait();
      RegKPart regKBuf;
      regKBuf[0] = loadRegKCol(smem.k[kBarWaiter.idxBuf], 0);

      auto shouldTestWait = [](uint32_t idxInstK, uint32_t idxAtomBx2) {
        return idxInstK == partNbInstK - 1 && idxAtomBx2 == tileNbAtomBx2 - 2;
      };
      BarWaiter kBarWaiterNext = kBarWaiter.next();
#if USE_REG_Q
#pragma unroll
      for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) {
#pragma unroll
        for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) {
          bool const prefetchNextPart = (idxInstK == partNbInstK - 1);
          uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart;
          uint32_t const idxInstKPrefetch = prefetchNextPart ? 0 : idxInstK + 1;
          bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1));

          if (prefetchNextPart) {
            kBarWaiter = kBarWaiterNext;
            kBarWaiterNext = kBarWaiter.next();
            if (prefetch) {
              kBarWaiter.wait();
            }
          }

          Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB);
#pragma unroll
          for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++) {
            if (idxAtomBx2 == 2 && prefetch) {
              if (idxPartPrefetch < SharedMemA::regQParts) {
                regQBuf[idxInstKPrefetch] = regQ[idxPartPrefetch][idxInstKPrefetch];
              } else {
                regQBuf[idxInstKPrefetch] =
                    loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch);
              }
            }
            AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2];
            regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2);
            if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch) {
              kBarWaiterNext.testWait();
            }
#pragma unroll
            for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
              for (uint32_t j = 0; j < 2; j++) {
                mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
                                   reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
                                   reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
              }
            }
            if (prefetch) {
              regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2);
            }
          }
          if (idxInstKPrefetch == partNbInstK - 1) {
            assert(prefetch);
            kBarWaiter.consumed();
          }
        }
      }
#endif
      if (ctaIter == 0) {
        smem.shmQBar.wait_parity(false);
      }
#pragma unroll
      for (uint32_t idxPart = SharedMemA::regQParts; idxPart < nbQParts; idxPart++) {
#pragma unroll
        for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) {
          bool const prefetchNextPart = (idxInstK == partNbInstK - 1);
          uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart;
          uint32_t const idxInstKPrefetch = prefetchNextPart ? 0 : idxInstK + 1;
          bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1));

          if (prefetchNextPart) {
            kBarWaiter = kBarWaiterNext;
            kBarWaiterNext = kBarWaiter.next();
            if (prefetch) {
              kBarWaiter.wait();
            }
          }

          Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB);
#pragma unroll
          for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++) {
            if (idxAtomBx2 == 2 && prefetch) {
              regQBuf[idxInstKPrefetch] =
                  loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch);
            }
            AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2];
            if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch) {
              kBarWaiterNext.testWait();
            }
#pragma unroll
            for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
              for (uint32_t j = 0; j < 2; j++) {
                mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
                                   reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
                                   reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
              }
            }
            if (prefetch) {
              regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2);
            }
          }
          if (idxInstKPrefetch == partNbInstK - 1) {
            assert(prefetch);
            kBarWaiter.consumed();
            if (idxPartPrefetch == nbKParts - 1) {
              tensorCoreMutex.unlock();  // let the other group use the tensor cores
            }
          }
        }
      }
      uint32_t const validTokens = seqLen - tokensPerTile * idxTile;
      if (validTokens < tokensPerTile) {
        applyMask(this_warp(), acc, 0, validTokens);
      }
      ThrdRegRowMax rowMaxLog2e;
      WarpAcc const xF32 = scaleAndSoftmax(rowMaxLog2e, acc, grpIdx, grpIter, tileBaseRow);

      auto& xBar = smem.xBars[grpIdx];
      bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter));
      // convert to fp8
      WarpAcc const xF32Quant = xF32 * rcpXScale;
      // 0, 1, 8, 9,  2, 3, 10, 11,  4, 5, 12, 13,  6, 7, 14, 15
      Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8;
#pragma unroll
      for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
        for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++) {
#pragma unroll
          for (uint32_t j = 0; j < WarpAcc::cols; j += 2) {
            auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0));
            dst[0] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j)(m, 0), xF32Quant(i, j)(m, 1)});
            dst[1] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j + 1)(m, 0), xF32Quant(i, j + 1)(m, 1)});
          }
        }
      }
      // use tensor core to compute rowSum
      ThrdRegRowMax const rowSum =
          computeRowSumFromF8 ? computeRowSumF8<warpTile.y, warpTile.x>(this_warp(), xF8)
                              : computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32);

      // store xF8, rowSum and rowMax into the shared memory x buffer; the sendX() warp copies it
      // to the L2 scratch buffer
      if (!skipXBarWait) {
        xBar.consumed.wait_parity(toParity<1>(grpIter));
      }
      storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
      storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane);
      storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane);
      xBar.produced.arrive();
    }
  }

  __device__ inline WarpAcc scaleAndSoftmax(ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc,
                                            uint32_t grpIdx, uint32_t grpIter,
                                            uint32_t tileBaseRow);

  __device__ inline void storeOrderedXToShm(
      XBuffer& dst,
      Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src,
      uint32_t const tileBaseRow, uint32_t const lane = laneId());
};

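// Streams K tile parts into the ring of SharedMemA::nbKBufs shared-memory buffers, one TMA load
// per (tile, part), synchronized with the math warps via kBars.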
__device__ inline void Producer::loadK() {
  KVTilePartLoader loader{args.cacheList, idxReq, args.tensorMapK, divUp(seqLen, tokensPerPage)};

#pragma unroll 1
  for (uint32_t iter = 0; true; iter++) {
    uint32_t const idxTile = idxTileBeg() + iterStride() * iter;
    if (idxTile >= nbTiles()) {
      break;
    }
    uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers;
    loader.loadPages(idxTile, idxPageBuf);
#pragma unroll 1
    for (uint32_t idxPart = 0; idxPart < nbKParts; idxPart++) {
      uint32_t const idxPartGlobal = iter * nbKParts + idxPart;
      uint32_t const idxBuf = idxPartGlobal % SharedMemA::nbKBufs;
      auto& bar = smem.kBars[idxBuf];
      bar.consumed.wait_parity(toParity<SharedMemA::nbKBufs>(idxPartGlobal));
      loader.loadData(smem.k[idxBuf], idxTile, partElemsK * idxPart, bar.produced, idxPageBuf);
      if (warpElectSync()) {
        bar.produced.arrive_tx(sizeof(SharedMemA::ShmKPart));
      }
    }
  }
}

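// Copies each finished CgaXBuffer from shared memory into this (token, sub-sequence) slot of the
// L2 scratch buffer and notifies both consumer CTAs via their cgaXBufProduced barriers; the xBars
// make the two math groups take turns filling the single shared-memory x buffer.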
__device__ inline void Producer::sendX() {
  // Let group 0 produce first.
  if (warpElectSync()) {
    smem.xBars[0].consumed.arrive();
  }
  for (uint32_t iter = 0; true; iter++) {
    uint32_t const idxTile = idxTileBeg() + iterStride() * iter;
    if (idxTile >= nbTiles()) {
      break;
    }
    uint32_t const idxBar = iter % SharedMemA::nbXBars;
    auto& xBar = smem.xBars[idxBar];
    xBar.produced.wait_parity(toParity<SharedMemA::nbXBars>(iter));
    smem.cgaXBufConsumed.wait_parity(toParity<1>(iter));
    if (warpElectSync()) {
      auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
      tma::store1DAsync(&dst, &smem.x, sizeof(CgaXBuffer));
      tma::commitGroup();
      tma::waitGroup<0>();
      // now it's the other math group's turn to produce.
      uint32_t const idxBarNext = (iter + 1) % SharedMemA::nbXBars;
      auto& xBarNext = smem.xBars[idxBarNext];
      xBarNext.consumed.arrive();
      asm volatile("fence.release.cluster;\n");
#pragma unroll
      for (uint32_t i = 0; i < nbVSplit; i++) {
        auto& producedBar = getConsumerShm(i).cgaXBufProduced[ctaRank];
        producedBar.arrive<Scope::CGA, ArriveOrder::RELAXED>();
      }
    }
  }
}

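// Computes this tile's row max (scaled by qkScale*log2e), merges it with the running row max
// received from the other producer CTA, forwards the merged max to that CTA for the next CGA
// iteration, and returns exp2f(acc * qkScaleLog2e - rowMaxLog2e).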
__device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(ThrdRegRowMax& rowMaxLog2e,
                                                              WarpAcc const& acc, uint32_t grpIdx,
                                                              uint32_t grpIter,
                                                              uint32_t tileBaseRow) {
  uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps;
  uint32_t const cgaIter = ctaRank + ctaIter * nbProducerCtasPerCga;
  auto const warp = this_warp();
  uint32_t const lane = laneId();
  uint32_t const idxProducer = ctaRank;
  assert(ctaRank < nbProducerCtasPerCga);

  float const qkScaleLog2e = smem.qkScaleLog2e;
  bool const skipWaitLastShmRowMax =
      smem.rowMaxLog2eBar[grpIdx].test_wait_parity(toParity<1>(grpIter));
  QuadRegRowMax const tileRowMaxLog2e = computeRowMax<warpTile.y, warpTile.x>(acc) * qkScaleLog2e;
  // get max with previous CTA's rowMax
  if (!skipWaitLastShmRowMax) {
    smem.rowMaxLog2eBar[grpIdx].wait_parity(toParity<1>(grpIter));
  }
  auto const lastRowMaxLog2e = loadShmRowMax<warpTile.y>(smem.rowMaxLog2e, tileBaseRow, lane);

  auto const quadRowMaxLog2e = fmaxf(tileRowMaxLog2e, replicateForQuad(warp, lastRowMaxLog2e));

  // transfer new row max to the other producer CTA for next iteration
  SharedMemA& smemNext = mapa(smem, ctaRank ^ 1U);
  CgaBarrier& nextRowMaxLog2eBar =
      smemNext.rowMaxLog2eBar[(cgaIter + 1) % (nbMathGrps * nbProducerCtasPerCga) / nbMathGrps];
  rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e);
  storeRowMaxAsync<warpTile.y>(nextRowMaxLog2eBar, smemNext.rowMaxLog2e, rowMaxLog2e, tileBaseRow,
                               lane);
  nextRowMaxLog2eBar.arrive_tx_relaxed(
      sizeof(rowMaxLog2e));  // notify that the next CTA can read rowMax now.

  WarpAcc x;
// apply softmax
#pragma unroll
  for (uint32_t m = 0; m < acc.rows; m++) {
#pragma unroll
    for (uint32_t i = 0; i < InstAcc::rows; i++) {
      float const maxVal = quadRowMaxLog2e[m * InstAcc::rows + i];
#pragma unroll
      for (uint32_t n = 0; n < acc.cols; n++) {
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++) {
          float elem = acc(m, n)(i, j);
          assert(maxVal >= elem * qkScaleLog2e);
          x(m, n)(i, j) = exp2f(elem * qkScaleLog2e - maxVal);
        }
      }
    }
  }

  return x;
}

__device__ inline void Producer::storeOrderedXToShm(
    XBuffer& dst,
    Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src,
    uint32_t const tileBaseRow, uint32_t const lane) {
  uint32_t const r = lane % 16;
  uint32_t const c = lane / 16;
  using Src = mha::decay_t<decltype(src)>;
  LdGrain* ptrs[exactDiv(Src::cols, 2)][Src::rows];
#pragma unroll
  for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++) {
    Mat16x32Loader const loader(dst, tileBaseRow, idxInstK, r, c);
#pragma unroll
    for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++) {
      auto const p = const_cast<LdGrain*>(loader.getPtr(idxInstM));
      stmatrix<false, 4>(p, reinterpret_cast<LdGrain const&>(src(idxInstM, idxInstK * 2)));
      ptrs[idxInstK][idxInstM] = p;
    }
  }
  // reorder from 0, 1, 8, 9,  2, 3, 10, 11,  4, 5, 12, 13,  6, 7, 14, 15
  // to 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  __syncwarp();
#pragma unroll
  for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++) {
#pragma unroll
    for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++) {
      auto const p = ptrs[idxInstK][idxInstM];
      auto const i = *p;
      LdGrain const o = {
          prmt(i[0], i[1], PermuteOrder{0, 1, 4, 5}), prmt(i[2], i[3], PermuteOrder{0, 1, 4, 5}),
          prmt(i[0], i[1], PermuteOrder{2, 3, 6, 7}), prmt(i[2], i[3], PermuteOrder{2, 3, 6, 7})};
      *p = o;
    }
  }
}

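// Consumer CTA: loads the producers' fp8 X tiles from the L2 scratch buffer and its half of V
// from the KV cache, accumulates O += X*V with online-softmax rescaling whenever the row max
// grows, then normalizes by the row sum and writes the output head (or a PartialResult chunk in
// multi-block mode).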
struct Consumer {
  static inline constexpr uint32_t nbMathWarps = nbMathWarpsB;
  static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps;
  static inline constexpr uint2 ctaShape = {2, 4};
  static_assert(SharedMemB::nbAccRowMaxSumCopies == ctaShape.x);
  static_assert(ctaShape.x * ctaShape.y == nbMathWarps);
  static inline constexpr uint2 warpTile = {exactDiv(gemm1V, ctaShape.x),
                                            exactDiv(headGrpSize, ctaShape.y)};

  static inline constexpr uint32_t nbWarpOutSwizzleBuf = nbMathWarps;
  using WarpOutSwizzleBuf =
      Array2D<LdGrain,
              exactDiv(exactDiv(sizeof(SharedMemB::XVBuffer), sizeof(OutputElem) * warpTile.x),
                       nbMathWarps),
              exactDiv(sizeof(OutputElem) * warpTile.x, grainBytes)>;
  static_assert(WarpOutSwizzleBuf::rows % 8 == 0);

  using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
  using ThrdRegRowMax = ThrdRegRowMaxT<warpTile.y>;
  using UniformNeedRescaleMask = Vec<uint32_t, divUp(warpTile.y, warp_size)>;

  KernelArgs const& args;
  SharedMemB& smem;
  uint32_t const maxNbSubSeq;
  uint32_t const idxReq;
  uint32_t const idxInputTokenGlobal;
  uint32_t const nbSubSeq;
  uint32_t const idxSubSeq;
  uint32_t const seqLen;
  uint32_t const ctaRank;
  uint32_t const warpRank;
  uint2 const warpIdx;

  __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; }

  __device__ inline uint32_t idxTileBeg() const { return nbProducerCtasPerCga * idxSubSeq; }

  __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); }

  __device__ inline uint32_t idxConsumer() const { return ctaRank - 2; }

  __device__ inline Consumer(KernelArgs const& args, SharedMemB& smem, uint32_t const maxNbSubSeq,
                             uint32_t const idxReq, uint32_t const idxInputTokenGlobal,
                             uint32_t const seqLen, uint32_t const nbSubSeq,
                             uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank,
                             uint2 const warpIdx)
      : args(args),
        smem(smem),
        maxNbSubSeq(maxNbSubSeq),
        idxReq(idxReq),
        idxInputTokenGlobal(idxInputTokenGlobal),
        seqLen(seqLen),
        nbSubSeq(nbSubSeq),
        idxSubSeq(idxSubSeq),
        ctaRank(ctaRank),
        warpRank(warpRank),
        warpIdx(warpIdx) {
#ifndef NDEBUG
    if (threadIdx.x == 0) {
      asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemB)) : "memory");
    }
    __syncthreads();
#endif
    if (threadIdx.x < headGrpSize) {
      for (uint32_t i = 0; i < SharedMemB::nbAccRowMaxSumCopies; i++) {
        smem.accRowMaxLog2e[i][threadIdx.x] = safeInitRowMax;
        smem.accRowSum[i][threadIdx.x] = 0;
      }
    }
    if (warpElectSync()) {
      if (warpRank < nbProducerCtasPerCga) {
        init(&smem.cgaXBufProduced[warpRank], 1);
      }
      if (warpRank < SharedMemB::nbXBufs) {
        auto& bar = smem.xBars[warpRank];
        bar.initialize(1, nbMathThrds);
        bar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
      }
      if (warpRank < SharedMemB::nbVBufs) {
        auto& bar = smem.vBars[warpRank];
        bar.initialize(1, nbMathThrds);
        bar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
      }
      if (warpRank == 0) {
        init(&smem.mathWarpsBar, warp_size * nbMathWarps);
      }
      if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) {
        auto& b = smem.multiBlockBars[warpRank];
        b.initialize(1, warp_size * multiBlockMathWarps);
        b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(warp_size * multiBlockMathWarps);
      }
    }
    clusterBarArrive();
    clusterBarWait();
  }

  __device__ inline ~Consumer() {
    clusterBarArrive();
    clusterBarWait();
    smem.invalidateBarriers(threadIdx.x);
  }

  __device__ inline void run() {
    if (warpIdx.y == 2) {
      asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
      if (warpIdx.x == 0) {
        loadX();
      } else if (warpIdx.x == 1) {
        loadV();
      }
    } else {
      asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
      compute();
    }
    if (nbSubSeq > 1) {
      mergePartialOutputs(args.semaphores[idxInputTokenGlobal],
                          reinterpret_cast<Vec<OutputHead, PartialResult::nbRowsPerChunk>&>(
                              args.output[headGrpSize * idxInputTokenGlobal +
                                          PartialResult::nbRowsPerChunk * ctaRank]),
                          args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq,
                          ctaRank, warpRank, warpIdx, &smem);
    }
  }

  __device__ inline void loadX();
  __device__ inline void loadV();
  __device__ inline void compute();

  __device__ inline uint32_t iterToTile(uint32_t iter) const {
    return idxTileBeg() + iterStride() * (iter / 2) + iter % 2;
  }

  __device__ inline SharedMemA& getProducerShm(uint32_t idxProducer) const {
    return mapa(reinterpret_cast<SharedMemA&>(smem), idxProducer);
  }

  using WarpOutputTile =
      Array2D<uint32_t, InstAcc::rows * Consumer::WarpAcc::rows, Consumer::WarpAcc::cols>;
  __device__ inline WarpOutputTile finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum,
                                            float xvScale, uint32_t lane = laneId());
  __device__ inline void storeOutput(Vec<OutputHead, warpTile.y>& dst, uint32_t dstBaseCol,
                                     WarpOutputTile const& regTile, WarpOutSwizzleBuf& swizzleBuf,
                                     uint32_t lane = laneId());
};

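// The 8 math warps form a 2x4 grid (ctaShape): warpIdx.y selects one half of this consumer's V
// columns and warpIdx.x selects a group of headGrpSize/4 rows. Each warp keeps its accumulator in
// registers across all tiles, rescaling it when the incoming tile's row max exceeds the running
// max stored in smem.accRowMaxLog2e.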
__device__ inline void Consumer::compute() {
  uint2 const tileIdx = {warpIdx.y, warpIdx.x};
  uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y};

  constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, qmmaShape.k);
  constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, qmmaShape.n * 2);

  uint32_t const lane = laneId();
  uint32_t const idxHalf = lane / 16;
  uint32_t const laneInHalf = lane % 16;
  uint32_t const rA = laneInHalf;
  uint32_t const cA = idxHalf;
  uint32_t const rB = lane;
  uint32_t const cB = 0;

  WarpAcc acc{};
  uint32_t idxXVBufLast{};
  for (uint32_t iter = 0; true; iter++) {
    uint32_t const idxTile = iterToTile(iter);
    if (idxTile >= nbTiles()) {
      break;
    }

    ThrdRegRowMax accRowMaxLog2e =
        loadShmRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane);
    ThrdRegRowMax accRowSum =
        loadShmRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], tileBase.y, lane);

    uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
    uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
    auto& xBar = smem.xBars[idxXBuf];
    auto& vBar = smem.vBars[idxVBuf];
    // @fixme: merge these two barriers and use test_wait_parity() early to avoid latency.
    bool const skipVBarWait = vBar.produced.test_wait_parity(toParity<SharedMemB::nbVBufs>(iter));
    xBar.produced.wait_parity(toParity<SharedMemB::nbXBufs>(iter));

    ThrdRegRowMax const xRowMaxLog2e =
        loadShmRowMax<warpTile.y>(smem.xRowMaxLog2e(idxXBuf), tileBase.y, lane);
    assert(all(accRowMaxLog2e <= xRowMaxLog2e));

    auto const needRescaleVec = (xRowMaxLog2e > accRowMaxLog2e);
    UniformNeedRescaleMask rescaleMask{};
#pragma unroll
    for (uint32_t i = 0; i < rescaleMask.size; i++) {
      rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]);
    }
    bool const anyNeedRescale = any(rescaleMask != UniformNeedRescaleMask::filled(0));
    if (anyNeedRescale) {
      auto const scaleVec = exp2f(accRowMaxLog2e - xRowMaxLog2e);
#pragma unroll
      for (uint32_t m = 0; m < WarpAcc::rows; m++) {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; i++) {
          uint8_t const mask =
              reinterpret_cast<uint8_t const(&)[2][2]>(rescaleMask[m / 2])[m % 2][i];
          bool const needRescale = (mask != 0);
          if (needRescale) {  // this branch is warp-uniform
            float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4);
#pragma unroll
            for (uint32_t n = 0; n < WarpAcc::cols; n++) {
#pragma unroll
              for (uint32_t j = 0; j < InstAcc::cols; j++) {
                acc(m, n)(i, j) *= scale;
              }
            }
          }
        }
      }
      accRowSum = accRowSum * scaleVec;
    }
    accRowMaxLog2e = xRowMaxLog2e;
    storeRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], accRowMaxLog2e, tileBase.y, lane);
    if (!skipVBarWait) {
      vBar.produced.wait_parity(toParity<SharedMemB::nbVBufs>(iter));
    }
    auto const& xBuf = smem.x(idxXBuf);
    auto const& vBuf = smem.v(idxVBuf)[tileIdx.x];
    auto const xRowSum = loadShmRowMax<warpTile.y>(smem.xRowSum(idxXBuf), tileBase.y, lane);
    accRowSum = accRowSum + xRowSum;
    storeRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], accRowSum, tileBase.y, lane);

#pragma unroll
    for (uint32_t idxInstK = 0; idxInstK < tileNbInstK; idxInstK++) {
      Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA);
      Vec<Mat16x32, exactDiv(warpTile.y, qmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>();
      using AtomB = Vec<uint32_t, 2>;
#pragma unroll
      for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < warpTileNbAtomBx2; idxAtomBx2++) {
        auto const data = ldmatrix_16x16_trans<2>(
            &vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
        AtomB const v[2] = {data[0], data[2], data[1], data[3]};
#pragma unroll
        for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
          for (uint32_t j = 0; j < 2; j++) {
#if 1
            mma<__nv_fp8_e4m3>(
#else
            mmaF8_k32_2inst(
#endif
                reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
                reinterpret_cast<uint32_t const(&)[2][2]>(x[i]),
                reinterpret_cast<uint32_t const(&)[2][1]>(v[j]));
          }
        }
      }
    }
    bool const isLastIter = (iterToTile(iter + 1) >= nbTiles());
    if (isLastIter) {
      idxXVBufLast = idxXBuf;
      assert(idxXBuf == idxVBuf);
    } else {
      xBar.consumed.arrive();
      vBar.consumed.arrive();
    }
  }

  smem.mathWarpsBar.arrive();

  ThrdRegRowMax const accRowSum =
      loadShmRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], tileBase.y, lane);
  float const xvScale = computeRowSumFromF8 ? args.kvCacheScale : args.kvCacheScale * xScale;
  WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane);

  bool const isMultiBlockMode = (nbSubSeq != 1);
  static_assert(PartialResult::nbRowsPerChunk == warpTile.y);
  auto& dst = isMultiBlockMode ? args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq]
                                     .chunks[tileIdx.y]
                                     .data
                               : reinterpret_cast<Vec<OutputHead, warpTile.y>&>(
                                     args.output[headGrpSize * idxInputTokenGlobal + tileBase.y]);

  assert(warpRank < nbMathWarps);
  WarpOutSwizzleBuf& swizzleBuf = reinterpret_cast<Vec<WarpOutSwizzleBuf, nbWarpOutSwizzleBuf>&>(
      smem.xv[idxXVBufLast])[warpRank];
  // make sure all math warps have finished using XVBuffer.
  smem.mathWarpsBar.wait_parity(false);

  storeOutput(dst, gemm1V * idxConsumer() + tileBase.x, output, swizzleBuf, lane);
  if (isMultiBlockMode && tileIdx.x == 0) {
    ThrdRegRowMax const accRowMaxLog2e =
        loadShmRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane);
    auto& chunk =
        args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y];
#pragma unroll
    for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) {
      chunk.rowMaxLog2e[warp_size * i + lane] = accRowMaxLog2e[i];
      chunk.rowSum[warp_size * i + lane] = accRowSum[i];
    }
  }
  smem.xBars[idxXVBufLast].consumed.arrive();
  smem.vBars[idxXVBufLast].consumed.arrive();
}

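// Copies one producer's CgaXBuffer (fp8 X plus rowSum/rowMax) from the L2 scratch buffer into
// shared memory with a linear TMA copy, then signals that producer that its scratch slot may be
// reused.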
__device__ inline void Consumer::loadX() {
#pragma unroll 1
  for (uint32_t iter = 0; true; iter++) {
    uint32_t const idxTile = iterToTile(iter);
    if (idxTile >= nbTiles()) {
      break;
    }
    // @todo: merge these two barriers.
    uint32_t const idxScratchXBuf = iter % nbProducerCtasPerCga;
    auto& srcProducedBar = smem.cgaXBufProduced[idxScratchXBuf];
    srcProducedBar.wait_parity(toParity<nbProducerCtasPerCga>(iter));
    uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
    auto& xBar = smem.xBars[idxXBuf];
    xBar.consumed.wait_parity(toParity<SharedMemB::nbXBufs>(iter));
    if (warpElectSync()) {
      auto& src = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf];
      auto& dst = smem.xv[idxXBuf].x;
      tma::loadLinearAsync(&dst, &src.x, sizeof(CgaXBuffer), xBar.produced);
      xBar.produced.arrive_tx(sizeof(CgaXBuffer));
      xBar.produced.wait_parity(toParity<SharedMemB::nbXBufs>(iter));
      uint32_t const idxProducer = idxScratchXBuf;
      // @fixme: check if this works. If it doesn't, randomly pick some data from dstX and dstRowSum
      // and use STAS + arrive_tx to avoid fence.
      getProducerShm(idxProducer).cgaXBufConsumed.arrive<Scope::CGA, ArriveOrder::RELAXED>();
    }
  }
}

__device__ inline void Consumer::loadV() {
  KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV, divUp(seqLen, tokensPerPage));
  for (uint32_t iter = 0; true; iter++) {
    uint32_t const idxTile = iterToTile(iter);
    if (idxTile >= nbTiles()) {
      break;
    }
    uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers;
    loader.loadPages(idxTile, idxPageBuf);
    uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
    auto& vBar = smem.vBars[idxVBuf];
    vBar.consumed.wait_parity(toParity<SharedMemB::nbVBufs>(iter));
#pragma unroll
    for (uint32_t idxPart = 0; idxPart < SharedMemB::VBuffer::size; idxPart++) {
      loader.loadData(
          smem.v(idxVBuf)[idxPart], idxTile,
          gemm1V * idxConsumer() + exactDiv(gemm1V, SharedMemB::VBuffer::size) * idxPart,
          vBar.produced, idxPageBuf);
    }
    if (warpElectSync()) {
      vBar.produced.arrive_tx(sizeof(SharedMemB::VBuffer));
    }
  }
}

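// Normalizes the accumulator by xvScale / rowSum (softmax normalization combined with the
// dequantization scales) and packs the result into bf16 pairs.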
__device__ inline Array2D<uint32_t, InstAcc::rows * Consumer::WarpAcc::rows,
                          Consumer::WarpAcc::cols>
Consumer::finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float const xvScale,
                   uint32_t const lane) {
  ThrdRegRowMax const scaleVec = (1.F / accRowSum) * xvScale;
  WarpOutputTile ret;
#pragma unroll
  for (uint32_t m = 0; m < WarpAcc::rows; m++) {
#pragma unroll
    for (uint32_t i = 0; i < InstAcc::rows; i++) {
      uint32_t retRow = m * InstAcc::rows + i;
      float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4);
#pragma unroll
      for (uint32_t n = 0; n < WarpAcc::cols; n++) {
        float data[InstAcc::cols];
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++) {
          data[j] = acc(m, n)(i, j) * scale;
        }
        assert(InstAcc::cols == 2);
        reinterpret_cast<__nv_bfloat162&>(ret(retRow, n)) =
            __float22bfloat162_rn(float2{data[0], data[1]});
      }
    }
  }
  return ret;
}

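// Stages the warp's output tile in a shared-memory swizzle buffer via stmatrix so that output
// rows become contiguous, then copies it grain by grain into the destination heads in global
// memory.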
__device__ inline void Consumer::storeOutput(Vec<OutputHead, warpTile.y>& dst, uint32_t dstBaseCol,
                                             WarpOutputTile const& src,
                                             WarpOutSwizzleBuf& swizzleBuf, uint32_t lane) {
  using Dst = mha::decay_t<decltype(dst)>;
  static_assert(Dst::size == WarpOutputTile::rows * 8 && Dst::size % WarpOutSwizzleBuf::rows == 0);
  uint32_t const nbIters = exactDiv(Dst::size, WarpOutSwizzleBuf::rows);

  uint32_t const rS = lane % 8;
  uint32_t const cS = lane / 8;

  uint32_t const thrdsPerRow =
      exactDiv(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols, grainBytes);
  static_assert(thrdsPerRow <= 32);
  uint32_t const rL = lane / thrdsPerRow;
  uint32_t const cL = lane % thrdsPerRow;
#pragma unroll
  for (uint32_t iter = 0; iter < nbIters; iter++) {
#pragma unroll
    for (uint32_t j = 0; j < WarpOutputTile::cols; j += 4) {
      auto const baseSwzPtr = &swizzleBuf.template at<true>(rS, j + cS);
      constexpr uint32_t srcRowsPerIter = exactDiv(WarpOutputTile::rows, nbIters);
#pragma unroll
      for (uint32_t i = 0; i < srcRowsPerIter; i++) {
        static_assert(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols * 8 % 1024 == 0);
        auto const swzPtr = checkedVal(baseSwzPtr + WarpOutputTile::cols * 8 * i,
                                       &swizzleBuf.template at<true>(8 * i + rS, j + cS));
        stmatrix<false, 4>(
            swzPtr, reinterpret_cast<Vec<uint32_t, 4> const&>(src(srcRowsPerIter * iter + i, j)));
      }
    }
    __syncwarp();

    uint32_t const dstRowsPerIter = WarpOutSwizzleBuf::rows;
    uint32_t const rowsPerOp = exactDiv(warp_size, thrdsPerRow);
    LdGrain* const baseDstPtr = reinterpret_cast<LdGrain*>(
        &dst[dstRowsPerIter * iter + rL]
            [dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]);
#pragma unroll
    for (uint32_t i = 0; i < dstRowsPerIter; i += rowsPerOp) {
      LdGrain* const dstPtr =
          checkedVal(baseDstPtr + i * exactDiv(sizeof(OutputHead), grainBytes),
                     reinterpret_cast<LdGrain*>(
                         &dst[dstRowsPerIter * iter + i + rL]
                             [dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]));
      LdGrain* const srcPtr = &swizzleBuf.template at<true>(i + rL, cL);
      *dstPtr = *srcPtr;
    }
    __syncwarp();
  }
}

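// Merges the per-sub-sequence partial outputs of one input token. The last CGA to finish
// (detected with an atomic increment on the per-token semaphore) streams each PartialResult chunk
// into shared memory and combines them. Stored chunk data is already normalized by its own
// rowSum, so it is multiplied back by rowSum before accumulation; when an incoming chunk has a
// larger rowMax, the accumulator and accumulated rowSum are first scaled by exp2f(oldMax -
// newMax), otherwise the chunk's contribution is scaled instead.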
__device__ inline void mergePartialOutputs(uint32_t& semaphore,
                                           Vec<OutputHead, PartialResult::nbRowsPerChunk>& dst,
                                           PartialResult const* reqPartialResults,
                                           uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank,
                                           uint2 warpIdx, void* sharedMem) {
  assert(nbSubSeq > 1);
  clusterBarArrive();
  clusterBarWait();
  bool const isProducer = (ctaRank < nbProducerCtasPerCga);

  bool& shmIsLastSubSeq = isProducer ? static_cast<SharedMemA*>(sharedMem)->isLastSubSeq
                                     : static_cast<SharedMemB*>(sharedMem)->isLastSubSeq;

  if (ctaRank == 3 && threadIdx.x == 0) {
    uint32_t old;
    uint32_t const lastOld = nbSubSeq - 1;
    asm volatile("atom.relaxed.gpu.global.inc.u32 %0, [%1], %2;\n"
                 : "=r"(old)
                 : "l"(&semaphore), "r"(lastOld));
    bool const isLastSubSeq = (old == lastOld);
#pragma unroll
    for (uint32_t i = 0; i < nbProducerCtasPerCga; i++) {
      static_cast<SharedMemA*>(mapa(sharedMem, i))->isLastSubSeq = isLastSubSeq;
    }
    mapa(shmIsLastSubSeq, 2) = isLastSubSeq;
    shmIsLastSubSeq = isLastSubSeq;
  }
  clusterBarArrive();
  clusterBarWait();
  bool const isLastCga = shmIsLastSubSeq;
  if (!isLastCga) {
    return;
  }

  CtaBarrierPair(&bars)[nbMultiBlockBufs] =
      isProducer ? static_cast<SharedMemA*>(sharedMem)->multiBlockBars
                 : static_cast<SharedMemB*>(sharedMem)->multiBlockBars;
  Vec<PartialResult::Chunk, nbMultiBlockBufs>& shmBufs =
      isProducer ? static_cast<SharedMemA*>(sharedMem)->getMultiBlockBufs()
                 : static_cast<SharedMemB*>(sharedMem)->getMultiBlockBufs();

  constexpr uint32_t nbShmBufs = nbMultiBlockBufs;

  if (warpIdx.y == 2) {
    asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
    if (warpIdx.x == 0) {
#pragma unroll 1
      for (uint32_t idxSubSeq = 0; idxSubSeq < nbSubSeq; idxSubSeq++) {
        uint32_t const idxBuf = idxSubSeq % nbShmBufs;
        auto& bar = bars[idxBuf];
        bar.consumed.wait_parity(toParity<nbShmBufs>(idxSubSeq));
        if (warpElectSync()) {
          tma::loadLinearAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank],
                               sizeof(PartialResult::Chunk), bar.produced);
          bar.produced.arrive_tx(sizeof(PartialResult::Chunk));
        }
      }
    }
  } else {
    asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
    constexpr uint32_t nbMathWarps = 8;
    constexpr uint32_t rowsPerWarp = exactDiv(PartialResult::nbRowsPerChunk, nbMathWarps);
    constexpr uint32_t regGrainsPerRow = exactDiv(sizeof(OutputHead), grainBytes * warp_size);
    constexpr uint32_t grainOutElems = exactDiv(grainBytes, sizeof(OutputElem));
    uint32_t const lane = laneId();

    uint32_t const tileRowBase = rowsPerWarp * warpRank;
    using RowWise = Vec<float, rowsPerWarp>;
    using RegChunk = Array2D<Vec<OutputElem, grainOutElems>, rowsPerWarp, regGrainsPerRow>;
    auto loadBuf = [&](RowWise& rowMaxLog2e, RowWise& rowSum, RegChunk& regChunk,
                       PartialResult::Chunk const& chunk) {
      auto loadRowWise = [&](Vec<float, PartialResult::nbRowsPerChunk> const& src) {
        return reinterpret_cast<RowWise const&>(src[tileRowBase]);
      };
      rowMaxLog2e = loadRowWise(chunk.rowMaxLog2e);
      rowSum = loadRowWise(chunk.rowSum);
#pragma unroll
      for (uint32_t i = 0; i < rowsPerWarp; i++) {
#pragma unroll
        for (uint32_t j = 0; j < regGrainsPerRow; j++) {
          regChunk(i, j) = reinterpret_cast<Vec<OutputElem, grainOutElems> const&>(
              chunk.data[tileRowBase + i][grainOutElems * (warp_size * j + lane)]);
        }
      }
    };

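    // Prologue: consume the first chunk to seed the running row max / row sum, and initialize
    // the accumulator with the chunk scaled by its own row sum.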
    uint32_t const idxSubSeqInit = 0;
    uint32_t const idxBufInit = idxSubSeqInit % nbShmBufs;
    bars[idxBufInit].produced.wait_parity(toParity<nbShmBufs>(idxSubSeqInit));
    RowWise accRowMaxLog2e;
    RowWise accRowSum;
    RegChunk chunk;
    loadBuf(accRowMaxLog2e, accRowSum, chunk, shmBufs[idxBufInit]);
    bars[idxBufInit].consumed.arrive();

    using Acc = Array2D<Vec<float, grainOutElems>, rowsPerWarp, regGrainsPerRow>;
    Acc acc;
#pragma unroll
    for (uint32_t i = 0; i < rowsPerWarp; i++) {
#pragma unroll
      for (uint32_t j = 0; j < regGrainsPerRow; j++) {
        acc(i, j) = convert<float>(chunk(i, j)) * accRowSum[i];
      }
    }

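    // Main merge loop: fold each remaining chunk into the accumulator with the online-softmax
    // update. If the new chunk has the larger row max, the accumulator is rescaled; otherwise
    // the chunk contribution is rescaled before being added.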
#pragma unroll 1
    for (uint32_t idxSubSeq = idxSubSeqInit + 1; idxSubSeq < nbSubSeq; idxSubSeq++) {
      uint32_t const idxBuf = idxSubSeq % nbShmBufs;
      auto& bar = bars[idxBuf];
      bar.produced.wait_parity(toParity<nbShmBufs>(idxSubSeq));
      RowWise chunkRowMaxLog2e;
      RowWise chunkRowSum;
      loadBuf(chunkRowMaxLog2e, chunkRowSum, chunk, shmBufs[idxBuf]);
      bar.consumed.arrive();
#pragma unroll
      for (uint32_t i = 0; i < rowsPerWarp; i++) {
        bool const newChunkGreater = (chunkRowMaxLog2e[i] > accRowMaxLog2e[i]);
        if (newChunkGreater) {
          float const scale = exp2f(accRowMaxLog2e[i] - chunkRowMaxLog2e[i]);
#pragma unroll
          for (uint32_t j = 0; j < regGrainsPerRow; j++) {
            acc(i, j) = acc(i, j) * scale + convert<float>(chunk(i, j)) * chunkRowSum[i];
          }
          accRowSum[i] = accRowSum[i] * scale + chunkRowSum[i];
          accRowMaxLog2e[i] = chunkRowMaxLog2e[i];
        } else {
          float const scale = exp2f(chunkRowMaxLog2e[i] - accRowMaxLog2e[i]);
          float const fusedScale = scale * chunkRowSum[i];
#pragma unroll
          for (uint32_t j = 0; j < regGrainsPerRow; j++) {
            acc(i, j) = acc(i, j) + convert<float>(chunk(i, j)) * fusedScale;
          }
          accRowSum[i] = accRowSum[i] + chunkRowSum[i] * scale;
        }
      }
    }

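    // Epilogue: normalize every row by its accumulated sum and write the merged head to dst.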
#pragma unroll
    for (uint32_t i = 0; i < rowsPerWarp; i++) {
      float const scale = 1.F / accRowSum[i];
      auto const dstHead = reinterpret_cast<Vec<OutputElem, grainOutElems>*>(&dst[tileRowBase + i]);
#pragma unroll
      for (uint32_t j = 0; j < regGrainsPerRow; j++) {
        dstHead[warp_size * j + lane] = convert<OutputElem>(acc(i, j) * scale);
      }
    }
  }
}

inline constexpr uint32_t cgaSize = nbProducerCtasPerCga + nbVSplit;

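// One CGA per (request, input token, sub-sequence): ranks [0, nbProducerCtasPerCga) run the
// Producer path and the remaining nbVSplit ranks run the Consumer path. Grid mapping:
// gridDim.x == cgaSize * inputSeqLen, gridDim.y == maxNbSubSeq (multi-block splits),
// gridDim.z == batchSize (MLA uses a single KV head).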
CUBIN_EXPORT __global__
__launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha(
    __grid_constant__ CUtensorMap const tensorMapQ,  // MhaIOHead[nbQHeads * totalNbInputTokens],
    __grid_constant__ CUtensorMap const tensorMapK,  // with box=64 for the least significant dim
    __grid_constant__ CUtensorMap const tensorMapV,  // with box=128 for the least significant dim
    float const qScale, float const* qScalePtr,
    OutputHead* __restrict__ const output,  // [totalNbInputTokens][nbQHeads]
    KVCacheList<usePagedKVCache> const cacheList, uint32_t const batchSize, float kvCacheScale,
    float const* kvScalePtr,  // Same scale for K and V cache. Used only for int8/fp8 KV cache.
    Vec<CgaXBuffer,
        nbProducerCtasPerCga>* __restrict__ const cgaXBuf,  // [totalNbInputTokens][maxNbSubSeq]
    uint32_t* __restrict__ const semaphores = nullptr,      // [totalNbInputTokens]
    PartialResult* __restrict__ const partialResults =
        nullptr)  // [totalNbInputTokens][maxNbSubSeq]
{
  float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
  float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
  assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1);
  extern __shared__ char smemBuf[];
  uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size);
  uint2 const warpIdx = {warpRank % 4, warpRank / 4};

  uint3 const& cgaId = clusterId();
  uint32_t const& idxReq = cgaId.z;
  uint32_t const& maxNbSubSeq = nbClusters().y;
  uint32_t const& idxSubSeq = cgaId.y;
  uint32_t const inputSeqLen =
      (allowMultipleInputTokens ? exactDiv(gridDim.x, cgaSize)
                                : checkedVal(1U, exactDiv(gridDim.x, cgaSize)));
  uint32_t const reqIdxInputToken =
      (allowMultipleInputTokens ? blockIdx.x / cgaSize : checkedVal(0U, blockIdx.x / cgaSize));
  uint32_t const idxInputTokenGlobal = inputSeqLen * idxReq + reqIdxInputToken;
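  // With a causal mask over inputSeqLen new tokens, input token i attends to
  // seqLen - (inputSeqLen - 1 - i) cached tokens.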
  uint32_t const cacheSeqLen = cacheList.seqLenList[idxReq] - (inputSeqLen - 1) + reqIdxInputToken;
  assert(beamWidth == 1);
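  // Multi-block mode splits the KV sequence of one input token across up to maxNbSubSeq CGAs
  // along gridDim.y; it is enabled only when every participating CTA gets at least
  // multiBlockMinNbTilesPerCta tiles. CGAs beyond nbSubSeq have nothing to do and exit.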
  uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tokensPerTile) : 0;
  bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTiles >= multiBlockMinNbTiles);
  uint32_t const nbSubSeq =
      isMultiBlockMode ? mha::min(nbTiles / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
  static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
  assert(isMultiBlockMode == (nbSubSeq > 1));
  if (idxSubSeq >= nbSubSeq) {
    return;
  }

  uint32_t const ctaRank = clusterCtaRank();
  bool const isProducer = (ctaRank < nbProducerCtasPerCga);

  KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV,    qScaleValue,
                        output,     cacheList,  batchSize,     kvCacheScaleValue,
                        cgaXBuf,    semaphores, partialResults};

  if (isProducer) {
    Producer{args,
             *reinterpret_cast<SharedMemA*>(smemBuf),
             maxNbSubSeq,
             idxReq,
             idxInputTokenGlobal,
             cacheSeqLen,
             nbSubSeq,
             idxSubSeq,
             ctaRank,
             warpRank,
             warpIdx}
        .run();
  } else {
    Consumer{args,
             *reinterpret_cast<SharedMemB*>(smemBuf),
             maxNbSubSeq,
             idxReq,
             idxInputTokenGlobal,
             cacheSeqLen,
             nbSubSeq,
             idxSubSeq,
             ctaRank,
             warpRank,
             warpIdx}
        .run();
  }
}

__constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB));
static_assert(smemSize <= 99 * 1024, "Shared memory size exceeded");
#endif  // IS_MLA

#ifndef GENERATE_CUBIN
#if IS_MLA
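// Builds a 2D TMA descriptor over the flattened Q tensor ([totalNbHeads][headElems]) with a
// [headGrpSize][partElems] box and 64B swizzle, matching the per-part K tiling used by the
// kernel (callers pass partElemsK).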
CUtensorMap makeTensorMapForQ(void const* addr, CUtensorMapDataType_enum dataType,
                              uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems) {
  CUtensorMap tensorMap{};
  uint64_t const globalDims[] = {headElems, totalNbHeads};
  uint32_t elemBytes = getElemBytes(dataType);
  uint32_t const headBytes = elemBytes * headElems;
  uint64_t const globalStrides[] = {headBytes};
  uint32_t const boxDims[] = {partElems, headGrpSize};
  uint32_t const elemStrides[] = {1, 1};
  auto const swizzle = CU_TENSOR_MAP_SWIZZLE_64B;

  checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 2, const_cast<void*>(addr), globalDims,
                                 globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
                                 swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
                                 CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
  return tensorMap;
}
#endif  // IS_MLA

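// Host-side launcher. scratch is interpreted as one Vec<CgaXBuffer, nbProducerCtasPerCga> per
// CGA followed by the PartialResult array used by multi-block mode (see cgaXBuf /
// partialResults below); semaphores holds one counter per input token.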
void launchMLA(
    cudaDeviceProp const& prop,
    uint32_t inputSeqLen,  // uniform for all requests and causal mask is assumed
    float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q,
    GMemCacheHead* kCacheVLLM,                // K cache pool for VLLM layout
    GMemCacheHead* vCacheVLLM,                // V cache pool for VLLM layout
    KVCachePageIndex const* kvCachePageList,  // device pointer. shape:
                                              // [batchSize][maxNbPagesPerSeq] (Layout 1)
    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
    float const* kvScalePtr,  // Same scale for K and V cache. Used only for int8/fp8 KV cache.
    uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
    uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
#if IS_MLA
  static_assert(
      SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
      "not implemented");
  if (beamWidth != 1) {
    throw std::runtime_error("not implemented");
  }
  static uint32_t const hostSmemSize = [&]() {
    // printf("smemSize = %u\n", smemSize);
    uint32_t size;
    checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
    checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
    return size;
  }();
  uint32_t const nbKHeads = 1;
  uint32_t const nbVHeads = nbKHeads;
  uint32_t const nbQHeads = nbKHeads * headGrpSize;
  uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
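  // Heuristic: target about (SM count / 4) * factor CGAs shared evenly among the
  // batchSize * nbKHeads request groups, clamped to at least 1 and to at most one sub-sequence
  // per two KV tiles. XQA_NB_SUB_SEQ overrides the heuristic when set to a positive value.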
  uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
    auto const env = std::getenv("XQA_NB_SUB_SEQ");
    if (env != nullptr) {
      int32_t const val = std::stoi(env);
      if (val > 0) {
        return val;
      }
    }
    float const factor = 4.f;
    return mha::min<uint32_t>(
        mha::max<uint32_t>(
            1U, (uint32_t)round(prop.multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
        divUp(maxSeqLen, tokensPerTile * 2));
  }();
  // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq);
  // gridDim.z == nbKHeads * batchSize, gridDim.y == nbSubSeqPerSeq, gridDim.x ==
  // cgaSize * inputSeqLen
  dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
  dim3 const dimCta{warp_size * 4 * 3, 1, 1};
  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
  uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
  KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen,
                                    maxNbPagesPerSeq};
  auto const dtype = [] {
    if (std::is_same_v<CacheElem, half>) {
      return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if (std::is_same_v<CacheElem, __nv_bfloat16>) {
      return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if (std::is_same_v<CacheElem, __nv_fp8_e4m3>) {
      return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    }
    throw std::runtime_error("unsupported cache element type");
  }();

  auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead,
                                            headGrpSize * inputSeqLen * batchSize, partElemsK);
  auto const tensorMapK = makeTensorMapForPagedKVCache(
      kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile,
      kv_stride_page, kv_stride_token, kv_stride_head);
  auto const tensorMapV = makeTensorMapForPagedKVCache(
      vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile,
      kv_stride_page, kv_stride_token, kv_stride_head);

  uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z;
  auto const cgaXBuf = static_cast<Vec<CgaXBuffer, nbProducerCtasPerCga>*>(scratch);
  auto const partialResults = reinterpret_cast<PartialResult*>(cgaXBuf + nbCgas);
  cudaError_t const err = cudaLaunchKernelEx(
      &launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, qScalePtr, output,
      cacheList, batchSize, kvCacheScale, kvScalePtr, cgaXBuf, semaphores, partialResults);
#else  // !IS_MLA: contiguous-cache launch path (mirrors the non-MLA launcher)
  KVCacheList<false> const cacheList{kvCacheData, seqLen, maxSeqLen};
  static_assert(!usePagedKVCache);
  assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens);
  auto const tensorMap = makeTensorMapForContiguousKVCache(
      kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth,
      batchSize, gemm0CtaTileNbTokens);
  cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads,
#if SLIDING_WINDOW
                     slidingWinSize,
#endif
                     qScale, output,
#if LOW_PREC_OUTPUT
                     rcpOutScale,
#endif
#if USE_INPUT_KV
                     qkv,
#if ROPE_STYLE != 0
                     ropeCosSin,
#endif
#else
                     q,
#endif
                     cacheList,
#if USE_BEAM_SEARCH
                     beamSearchParams,
#endif
                     batchSize, kvCacheScale, tensorMap, semaphores, scratch);
#endif  // IS_MLA
  checkCuda(err);
}

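// Runs once at static-initialization time to read the kernel's shared memory requirement and
// raise its dynamic shared memory limit; used by launchMLAFlashInfer below.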
static uint32_t configureKernel() {
  uint32_t size;
  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
  cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
  return size;
}

static uint32_t const hostSmemSize = configureKernel();

void launchMLAFlashInfer(
    uint32_t multiProcessorCount,
    uint32_t inputSeqLen,  // uniform for all requests and causal mask is assumed
    float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q,
    GMemCacheHead* kCacheVLLM,                // K cache pool for VLLM layout
    GMemCacheHead* vCacheVLLM,                // V cache pool for VLLM layout
    KVCachePageIndex const* kvCachePageList,  // device pointer. shape:
                                              // [batchSize][maxNbPagesPerSeq] (Layout 1)
    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
    float const* kvScalePtr,  // Same scale for K and V cache. Used only for int8/fp8 KV cache.
    uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
    uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
#if IS_MLA
  static_assert(
      SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
      "not implemented");
  if (beamWidth != 1) {
    throw std::runtime_error("not implemented");
  }
  uint32_t const nbKHeads = 1;
  uint32_t const nbVHeads = nbKHeads;
  uint32_t const nbQHeads = nbKHeads * headGrpSize;
  uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
  /*uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
    float const factor = 4.f;
    return mha::min<uint32_t>(
        mha::max<uint32_t>(
            1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
        divUp(maxSeqLen, tokensPerTile * 2));
  }();*/ // MLA disables multi-block mode for now
  // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq);
  // gridDim.z == nbKHeads * batchSize, gridDim.y == 1 (multi-block disabled), gridDim.x ==
  // cgaSize * inputSeqLen
  dim3 const dimGrid{4 * inputSeqLen, 1, nbKHeads * batchSize};
  dim3 const dimCta{warp_size * 4 * 3, 1, 1};
  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
  uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
  KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen,
                                    maxNbPagesPerSeq};
  auto const dtype = [] {
    if (std::is_same_v<CacheElem, half>) {
      return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if (std::is_same_v<CacheElem, __nv_bfloat16>) {
      return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if (std::is_same_v<CacheElem, __nv_fp8_e4m3>) {
      return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    }
    throw std::runtime_error("unsupported cache element type");
  }();

  auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead,
                                            headGrpSize * inputSeqLen * batchSize, partElemsK);
  auto const tensorMapK = makeTensorMapForPagedKVCache(
      kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile,
      kv_stride_page, kv_stride_token, kv_stride_head);
  auto const tensorMapV = makeTensorMapForPagedKVCache(
      vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile,
      kv_stride_page, kv_stride_token, kv_stride_head);

  uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z;
  auto const cgaXBuf = static_cast<Vec<CgaXBuffer, nbProducerCtasPerCga>*>(scratch);
  auto const partialResults = reinterpret_cast<PartialResult*>(cgaXBuf + nbCgas);
  cudaError_t const err = cudaLaunchKernelEx(
      &launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, qScalePtr, output,
      cacheList, batchSize, kvCacheScale, kvScalePtr, cgaXBuf, semaphores, partialResults);
  checkCuda(err);
#endif
}
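
#endif  // GENERATE_CUBIN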
